{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp data.mixed_augmentation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Label-mixing transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">Callbacks that perform data augmentation by mixing samples in different ways."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "from torch.distributions.beta import Beta\n",
    "from fastai.callback.core import Callback\n",
    "from fastai.layers import NoneReduce\n",
    "from tsai.imports import *\n",
    "from tsai.utils import *\n",
    "warnings.filterwarnings(\"ignore\", category=UserWarning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def _reduce_loss(loss, reduction='mean'):\n",
    "    \"Reduce the loss based on `reduction`\"\n",
    "    return loss.mean() if reduction == 'mean' else loss.sum() if reduction == 'sum' else loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class MixHandler1d(Callback):\n",
    "    \"\"\"A handler class for implementing mixed sample data augmentation.\n",
    "\n",
    "    When the learner's loss function exposes a truthy `y_int` attribute, it is replaced by\n",
    "    `self.lf` for each training phase and put back afterwards, also when the training phase\n",
    "    or the whole fit is cancelled. `lf` reads `self.lam` and `self.yb1`, which this class\n",
    "    never assigns: they have to be set elsewhere before the loss is computed.\n",
    "    \"\"\"\n",
    "    run_valid = False\n",
    "    _lf_swapped = False  # True only while `self.lf` is installed as `learn.loss_func`\n",
    "\n",
    "    def __init__(self, alpha=0.5):\n",
    "        \"Store the symmetric `Beta(alpha, alpha)` distribution in `self.distrib`\"\n",
    "        self.distrib = Beta(alpha, alpha)\n",
    "\n",
    "    def before_train(self):\n",
    "        \"Detect whether batches carry targets and, if the loss declares `y_int`, install `self.lf`\"\n",
    "        if not self.training: return\n",
    "        b = self.learn.dls.one_batch()\n",
    "        # A batch with more than one element is treated as (inputs, targets)\n",
    "        self.labeled = len(b) > 1\n",
    "        if self.labeled:\n",
    "            self.stack_y = getattr(self.learn.loss_func, 'y_int', False)\n",
    "            if self.stack_y:\n",
    "                self.old_lf, self.learn.loss_func = self.learn.loss_func, self.lf\n",
    "                self._lf_swapped = True\n",
    "\n",
    "    def after_train(self):\n",
    "        \"Put the original loss function back at the end of the training phase\"\n",
    "        self._restore_loss_func()\n",
    "\n",
    "    def after_cancel_train(self):\n",
    "        \"Put the original loss function back when the training phase is cancelled\"\n",
    "        self._restore_loss_func()\n",
    "\n",
    "    def after_cancel_fit(self):\n",
    "        \"Put the original loss function back when the whole fit is cancelled\"\n",
    "        # Without this hook a fit cancelled mid-training could leave `self.lf` installed,\n",
    "        # so the next fit would no longer see the real loss function (nor its `y_int`).\n",
    "        self._restore_loss_func()\n",
    "\n",
    "    def _restore_loss_func(self):\n",
    "        \"Reinstall the loss function saved in `before_train`; does nothing if `self.lf` is not installed\"\n",
    "        if self._lf_swapped:\n",
    "            self.learn.loss_func = self.old_lf\n",
    "            self._lf_swapped = False\n",
    "\n",
    "    def lf(self, pred, *yb):\n",
    "        \"Training loss: interpolate the unreduced losses against `self.yb1` and `yb` with weight `self.lam`\"\n",
    "        if not self.training: return self.old_lf(pred, *yb)\n",
    "        # NoneReduce yields per-item losses so that `self.lam` can weight them before reduction\n",
    "        with NoneReduce(self.old_lf) as lf: loss = torch.lerp(lf(pred, *self.yb1), lf(pred, *yb), self.lam)\n",
    "        return _reduce_loss(loss, getattr(self.old_lf, 'reduction', 'mean'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class MixUp1d(MixHandler1d):\n",
    "    \"Implementation of https://arxiv.org/abs/1710.09412\"\n",
    "\n",
    "    def __init__(self, alpha=.4):\n",
    "        super().__init__(alpha)\n",
    "\n",
    "    def before_batch(self):\n",
    "        \"Blend each training sample (and its target, unless the loss stacks targets) with a randomly chosen partner\"\n",
    "        if not self.training: return\n",
    "        bs = self.x.size(0)\n",
    "        lam = self.distrib.sample((bs, ))\n",
    "        # Keep the larger coefficient so every sample dominates its own mix\n",
    "        self.lam = torch.max(lam, 1 - lam).to(self.x.device)\n",
    "        shuffle = torch.randperm(bs)\n",
    "        # Index every input with the same permutation so that sample i is paired with sample shuffle[i].\n",
    "        # Zipping a bare tensor with `self.xb` would iterate over its rows and blend the whole batch\n",
    "        # with a single sample, which no longer matches the per-sample targets in `self.yb1`.\n",
    "        xb1 = tuple(x[shuffle] for x in self.xb)\n",
    "        self.learn.xb = tuple(torch.lerp(x1, x, weight=unsqueeze(self.lam, n=x.ndim - 1)) for x1, x in zip(xb1, self.xb))\n",
    "        if self.labeled:\n",
    "            self.yb1 = tuple((self.y[shuffle], ))\n",
    "            if not self.stack_y:\n",
    "                self.learn.yb = tuple(torch.lerp(y1, y, weight=unsqueeze(self.lam, n=y.ndim - 1)) for y1, y in zip(self.yb1, self.yb))\n",
    "\n",
    "MixUp1D = MixUp1d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.learner import *\n",
    "from tsai.models.InceptionTime import *\n",
    "from tsai.data.external import get_UCR_data\n",
    "from tsai.data.core import get_ts_dls, TSCategorize\n",
    "from tsai.data.preprocessing import TSStandardize\n",
    "from tsai.learner import ts_learner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
       "        background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>1.908455</td>\n",
       "      <td>1.811908</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Smoke test: fit InceptionTime for one epoch on NATOPS with MixUp1d attached\n",
    "X, y, splits = get_UCR_data('NATOPS', split_data=False)\n",
    "tfms = [None, TSCategorize()]\n",
    "batch_tfms = TSStandardize()\n",
    "dls = get_ts_dls(X, y, tfms=tfms, splits=splits, batch_tfms=batch_tfms)\n",
    "learn = ts_learner(dls, InceptionTime, cbs=MixUp1d(0.4))\n",
    "learn.fit_one_cycle(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class CutMix1d(MixHandler1d):\n",
    "    \"Implementation of `https://arxiv.org/abs/1905.04899`\"\n",
    "\n",
    "    def __init__(self, alpha=1.):\n",
    "        super().__init__(alpha)\n",
    "\n",
    "    def before_batch(self):\n",
    "        \"Paste one random window of a shuffled copy of the batch into `x` and set `lam` to the fraction of steps left untouched\"\n",
    "        if not self.training: return\n",
    "        bs, *_, seq_len = self.x.size()\n",
    "        # A single coefficient is drawn per batch, so every sample shares the same window\n",
    "        lam = self.distrib.sample((1, ))\n",
    "        self.lam = torch.max(lam, 1 - lam).to(self.x.device)\n",
    "        shuffle = torch.randperm(bs)\n",
    "        xb1 = self.x[shuffle]\n",
    "        x1, x2 = self.rand_bbox(seq_len, self.lam)\n",
    "        self.learn.xb[0][..., x1:x2] = xb1[..., x1:x2]\n",
    "        # Recompute lam from the window that was actually pasted (its bounds are clamped to the sequence)\n",
    "        self.lam = (1 - (x2 - x1) / float(seq_len)).item()\n",
    "        if self.labeled:\n",
    "            self.yb1 = tuple((self.y[shuffle], ))\n",
    "            if not self.stack_y:\n",
    "                # `self.lam` is a plain Python float at this point: hand it to `torch.lerp` as a scalar\n",
    "                # weight, which applies to targets of any rank (a float cannot be unsqueezed).\n",
    "                self.learn.yb = tuple(torch.lerp(y1, y, self.lam) for y1, y in zip(self.yb1, self.yb))\n",
    "\n",
    "    def rand_bbox(self, seq_len, lam):\n",
    "        \"Return the clamped `(x1, x2)` bounds of a randomly centred window spanning about `1 - lam` of `seq_len`\"\n",
    "        cut_seq_len = torch.round(seq_len * (1. - lam)).type(torch.long)\n",
    "        half_cut_seq_len = torch.div(cut_seq_len, 2, rounding_mode='floor')\n",
    "\n",
    "        # uniform\n",
    "        cx = torch.randint(0, seq_len, (1, ), device=lam.device)\n",
    "        x1 = torch.clamp(cx - half_cut_seq_len, 0, seq_len)\n",
    "        x2 = torch.clamp(cx + half_cut_seq_len, 0, seq_len)\n",
    "        return x1, x2\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class IntraClassCutMix1d(Callback):\n",
    "    \"CutMix variant in which a sample only receives a window from a sample of its own class\"\n",
    "    run_valid = False\n",
    "\n",
    "    def __init__(self, alpha=1.):\n",
    "        self.distrib = Beta(tensor(alpha), tensor(alpha))\n",
    "\n",
    "    def before_batch(self):\n",
    "        \"Overwrite one random window of every training sample with the same window of a same-class sample\"\n",
    "        if not self.training: return\n",
    "        bs, *_, seq_len = self.x.size()\n",
    "        positions = torch.arange(bs, device=self.x.device)\n",
    "        targets = torch.tensor(self.y)\n",
    "        # Batch positions grouped per class, in the order returned by `torch.unique`\n",
    "        groups = [positions[torch.eq(targets, label)] for label in torch.unique(targets).tolist()]\n",
    "        receivers = torch.cat(groups)\n",
    "        # Same grouping with every group shuffled, so donor and receiver always share a class\n",
    "        donors = torch.cat([random_shuffle(group) for group in groups])\n",
    "        self.lam = self.distrib.sample((1, )).to(self.x.device)\n",
    "        start, stop = self.rand_bbox(seq_len, self.lam)\n",
    "        donor_x = self.x[donors]\n",
    "        self.learn.xb[0][receivers, :, start:stop] = donor_x[..., start:stop]\n",
    "\n",
    "    def rand_bbox(self, seq_len, lam):\n",
    "        \"Return the clamped `(start, stop)` bounds of a randomly centred window spanning about `1 - lam` of `seq_len`\"\n",
    "        width = torch.round(seq_len * (1. - lam)).type(torch.long)\n",
    "        half_width = torch.div(width, 2, rounding_mode='floor')\n",
    "\n",
    "        # the window centre is drawn uniformly over the sequence\n",
    "        centre = torch.randint(0, seq_len, (1, ), device=lam.device)\n",
    "        start = torch.clamp(centre - half_width, 0, seq_len)\n",
    "        stop = torch.clamp(centre + half_width, 0, seq_len)\n",
    "        return start, stop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
       "        background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>1.813483</td>\n",
       "      <td>1.792010</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Smoke test: fit InceptionTime for one epoch on NATOPS with IntraClassCutMix1d attached\n",
    "X, y, splits = get_UCR_data('NATOPS', split_data=False)\n",
    "tfms = [None, TSCategorize()]\n",
    "batch_tfms = TSStandardize()\n",
    "dls = get_ts_dls(X, y, tfms=tfms, splits=splits, batch_tfms=batch_tfms)\n",
    "learn = ts_learner(dls, InceptionTime, cbs=IntraClassCutMix1d())\n",
    "learn.fit_one_cycle(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
       "        background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>1.824509</td>\n",
       "      <td>1.774964</td>\n",
       "      <td>00:04</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Smoke test: fit for one epoch on NATOPS with CutMix1d attached\n",
    "# (no architecture argument is passed to `ts_learner` in this cell)\n",
    "X, y, splits = get_UCR_data('NATOPS', split_data=False)\n",
    "tfms = [None, TSCategorize()]\n",
    "batch_tfms = TSStandardize()\n",
    "dls = get_ts_dls(X, y, tfms=tfms, splits=splits, batch_tfms=batch_tfms)\n",
    "learn = ts_learner(dls, cbs=CutMix1d(1.))\n",
    "learn.fit_one_cycle(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": "IPython.notebook.save_checkpoint();",
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/nacho/notebooks/tsai/nbs/011_data.mixed_augmentation.ipynb couldn't be saved automatically. You should save it manually 👋\n",
      "Correct notebook to script conversion! 😃\n",
      "Tuesday 20/06/23 14:18:05 CEST\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "                <audio  controls=\"controls\" autoplay=\"autoplay\">\n",
       "                    <source src=\"data:audio/wav;base64,UklGRvQHAABXQVZFZm10IBAAAAABAAEAECcAACBOAAACABAAZGF0YdAHAAAAAPF/iPh/gOoOon6w6ayCoR2ZeyfbjobxK+F2Hs0XjKc5i3DGvzaTlEaraE+zz5uLUl9f46fHpWJdxVSrnfmw8mYEScqUP70cb0Q8X41uysJ1si6Eh1jYzXp9IE2DzOYsftYRyoCY9dJ/8QICgIcEun8D9PmAaBPlfT7lq4MFIlh61tYPiCswIHX+yBaOqT1QbuW7qpVQSv9lu6+xnvRVSlyopAypbGBTUdSalrSTaUBFYpInwUpxOzhti5TOdndyKhCGrdwAfBUcXIJB69p+Vw1egB76+n9q/h6ADglbf4LvnIHfF/981ODThF4m8HiS0riJVjQ6c+/EOZCYQfJrGrhBmPVNMmNArLKhQlkXWYqhbaxXY8ZNHphLuBJsZUEckCTFVHMgNKGJytIDeSUmw4QN4Qx9pReTgb3vYX/TCBuApf75f+P5Y4CRDdN+B+tngk8c8nt03CKGqipgd13OhotwOC5x9MCAknFFcmlmtPmagFFFYOCo0qRzXMhVi57pryNmIEqJlRi8bm52PfuNM8k4dfQv+4cO12l6zCGdg3jl730uE/KAPvS+f0wEAoAsA89/XfXQgBESIn6S5luDtiC8eh/YmIfpLqt1OMp5jXg8/24MveqUNUnPZsqw0Z3yVDldnaUOqIZfXlKrm36zzWhjRhaT+r+ncHI5/otUzfd2uSt7hl/bqXtoHaCC6+mqfrAOeoDD+PJ/xf8RgLMHfH/b8GeBihZIfSXidoQSJWB52NM1iRkzz3MkxpKPbUCrbDu5d5fgTAxkSK3JoEhYD1p2omere2LZTuqYLbdWa49Cx5Dww7tyXDUnioXRkHhwJyKFvd/AfPoYy4Fl7j1/LQorgEr9/X89+0qAOAwAf13sJoL8Gkd8wt25hWIp3Heez/eKODfPcSPCzpFNRDVqf7UlmnNQKGHgqd+jgVvJVm2f265QZTpLS5byur1tpT6ajvrHq3Q2MXWIxtUCehoj8YMk5LB9hRQegeTypn+nBQWA0QHgf7f2q4C5EFt+5ucOg2YfHXtq2SSHpS0ydnTL4IxFO6pvNb4ulBdInWfcsfSc7VMmXpSmE6eeXmZThJxpsgRohEfOk86+AHCoOpOMFsx1dv8s6oYT2k17uR7ngpXod34IEJqAaPfnfyABCIBZBpl/NPI2gTQVjX134x2ExSPMeR7VtYjZMWJ0W8ftjkA/YW1durCWykvjZFKu4p9LVwVbZKNkqpxh6U+6mRC2mGq2Q3SRvsIgcpc2sIpD0Bp4uiiFhW3ecXxOGgaCDe0Vf4cLPoDv+/5/mfw1gN4KKX+17emBqBmYfBHfVYUZKFR44NBtiv41bHJUwx+RJkP1apu2VJlkTwli4qrwoo1ax1dToNCtemRSTBGXz7kJbdM/PY/Dxht0dTLziH7Ul3loJEiE0uJsfdsVTYGL8Yt/AgcMgHYA7X8S+IqAYA+QfjzpxIIVHnp7tdqzhmAstXaxzEqMETpScGC/dJP3Rmdo8LIZnOVSEF+Opxumsl1sVF+dVrE5Z6NIiZSkvVdv2zsqjdnK8HVDLlyHyNjuegogM4NA5z9+YRG9gA722H97AgOA/gSyf43zCIHdE899yuTIg3ciNXpm1jmImTDwdJPITI4RPhRugbvslbFKt2Vfr/6eTFb4W1WkY6m6YPdQjJr2tNZp3EQlko7BgXHRNz2LAc+gdwMq7IUf3R58ohtFgrbr6n7hDFWAlPr8f/T9I4CECU9/De+vgVQY5nxh4POEzybJeCTS5YnCNAZzhsRzkP1Bsmu4t4aYU07nYuerA6KWWcJYO6HHrKJjaE3Zl624UWz/QOOPjcWHc7QzdIk40yl5tCWjhIDhJX0xF4CBMvBsf10IF4Ac//Z/bPlsgAcO
wn6S6n6CwxzUewLcRoYaKzV38M23i9o493CNwL6S1UUuaQe0QpvbUfdfiqglpcRccFU+nkWwambASUiVfLyqbg49xY2eyWh1hy/Sh37XjHpaIYKD7OUEfrgS5IC09MV/1gMBgKMDyH/n9N6AhhINfh7mdoMoIZt6r9fAh1cvfHXNya6N4DzDbqi8K5WWSYlmbbAdnkpV6FxJpWSo1V8DUmGb3rMRaQBG2JJgwN9wCDnNi8HNI3dKK1aG0dvHe/UciIJf6rt+Og5wgDn59X9P/xWAKQhxf2XweYH+FjB9suGVhIMlOnlo02GJhTOdc7vFyo/TQGxs2Li7lz9NwmPurBihnVi7WSWiwKvGYntOpJiOt5drKUKMkFnE8HLxNPmJ9NG4eP8mAYUv4Np8hhi3gdruSX+3CSWAwP38f8f6UoCuDPF+6Os8gnAbKnxQ3d2F0imydzDPKIuiN5lxu8EKkrFE82kftW2az1DbYImpMqTUW3FWIJ83r5hl2koJlla7+m0+PmSOZcjcdMgwS4g11iZ6qCLUg5jkxn0QFA6BWvOvfzEFBIBHAtp/Qfa3gC4RSH5y5yeD2B/8evnYS4cULgR2CMsUja47cG/QvW6UeEhXZ3+xP51GVNVdP6Zpp+1eDFM5nMeySWghR4+TNL85cD46YIyCzKJ2kCzEhoTabXtGHs+CCemJfpMPjoDe9+t/qQALgM8Gj3++8UaBqRV2fQTjO4Q3JKd5r9TgiEYyMHTxxiWPpz8jbfq585YpTJpk960xoKFXsVoTo7yq6GGMTw==\" type=\"audio/wav\" />\n",
       "                    Your browser does not support the audio element.\n",
       "                </audio>\n",
       "              "
      ],
      "text/plain": [
       "<IPython.lib.display.Audio object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#|eval: false\n",
    "#|hide\n",
    "# Look up this notebook's name and pass it to tsai's `create_scripts` helper\n",
    "from tsai.export import get_nb_name; nb_name = get_nb_name(locals())\n",
    "from tsai.imports import create_scripts; create_scripts(nb_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
